import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

# Style 
plt.style.use('ggplot')
mpl.rcParams['axes.edgecolor'] = 'black'
mpl.rcParams['axes.linewidth'] = 1.2

# Fonction pour calculer les quantités à un avancement donné
def quantites(n1, n2, x):
    return {
        'I₂': max(n1 - x, 0),
        'S₂O₃²⁻': max(n2 - 2 * x, 0),
        'I⁻': 2 * x,
        'S₄O₆²⁻': x,
    }

# Couleurs des espèces chimiques
couleurs = {
    'I₂': '#00bcd4',
    'S₂O₃²⁻': '#4caf50',
    'I⁻': '#f48fb1',
    'S₄O₆²⁻': '#ffcc80',
}

# Fonction principale d'affichage
def plot_state(n1, n2, x):
    x_max = min(n1, n2 / 2)
    erreur = x > x_max
    if erreur:
        print("⚠️ Avancement trop grand : la réaction est terminée.")
        x = x_max

    etat = quantites(n1, n2, x)
    species = ['I₂', 'S₂O₃²⁻', 'I⁻', 'S₄O₆²⁻']
    positions = np.arange(len(species))
    quantites_initiales = [n1, n2, 0, 0]
    quantites_actuelles = [etat[s] for s in species]
    y_disappear = [quantites_initiales[i] - quantites_actuelles[i] for i in range(len(species))]

    plt.figure(figsize=(10, 6))

    # Texte au-dessus
    plt.text(
        1.5, 1.15 * max(n1, n2),
        f"n₀(I₂) = {n1:.2f} mmol   |   n₀(S₂O₃²⁻) = {n2:.2f} mmol   |   x = {x:.2f} mmol",
        ha='center', fontsize=13, weight='bold'
    )

    # Barres
    plt.bar(positions, quantites_initiales, width=0.5, color='lightgray', edgecolor='black', alpha=0.5)
    plt.bar(positions, y_disappear, width=0.5, bottom=quantites_actuelles, color=['#d0f0fd', '#b2fab4', 'lightgray', 'lightgray'], edgecolor='black')
    plt.bar(positions, quantites_actuelles, width=0.5, color=[couleurs[s] for s in species], edgecolor='black')

    # Flèches Δn
    for pos, s, n_init in zip(positions, species, quantites_initiales):
        n_final = etat[s]
        delta = n_final - n_init
        if n_init != n_final:
            y_start = max(n_init, n_final)
            y_end = min(n_init, n_final)
            y_middle = (n_init + n_final) / 2
            plt.annotate('', xy=(pos, y_end), xytext=(pos, y_start), arrowprops=dict(arrowstyle='<->', color='black', lw=1.5))
            signe = '+' if delta > 0 else '−'
            plt.text(pos + 0.1, y_middle, f"Δn = {signe}{abs(delta):.2f} mmol", ha='left', va='center', fontsize=10)

    # Infos en dessous
    plt.figtext(0.5, -0.12,
        f"Quantités actuelles : n(I₂) = {etat['I₂']:.2f} mmol, "
        f"n(S₂O₃²⁻) = {etat['S₂O₃²⁻']:.2f} mmol, "
        f"n(I⁻) = {etat['I⁻']:.2f} mmol, "
        f"n(S₄O₆²⁻) = {etat['S₄O₆²⁻']:.2f} mmol",
        ha='center', fontsize=12, weight='bold'
    )

    if erreur:
        plt.figtext(0.5, -0.08,
            f"Avancement trop grand ! xmax = {x_max:.2f} mmol",
            ha='center', fontsize=13, color='red', weight='bold'
        )

    plt.xticks(positions, species, fontsize=12)
    plt.ylabel("Quantité de matière (mmol)", fontsize=12)
    plt.title("Évolution des quantités de matière", fontsize=14, weight='bold')
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.ylim(0, 1.3 * max(n1, n2))
    plt.tight_layout()
    plt.show()
    plt.close()

# ---- TEST ----
# Exemple d'appel
if __name__ == "__main__":
    # Tu peux modifier ces valeurs pour faire différents tests
    n1 = 2.5    # mmol
    n2 = 6.0    # mmol
    x = 2.0     # mmol
    plot_state(n1, n2, x)

if __name__ == "__main__":
    print("Simulation de l'évolution des quantités de matière dans la réaction :")
    print("2 S₂O₃²⁻ + I₂ → 2 I⁻ + S₄O₆²⁻")
    print("Entrez les données en mmol. Tapez 'q' pour quitter.\n")

    while True:
        try:
            n1_input = input("n₀(I₂) (mmol) = ")
            if n1_input.lower() == 'q':
                break
            n1 = float(n1_input)

            n2_input = input("n₀(S₂O₃²⁻) (mmol) = ")
            if n2_input.lower() == 'q':
                break
            n2 = float(n2_input)

            x_input = input("Avancement x (mmol) = ")
            if x_input.lower() == 'q':
                break
            x = float(x_input)

            plot_state(n1, n2, x)
            print("\n--- Nouvelle simulation ---\n")

        except ValueError:
            print("Entrée invalide. Veuillez entrer un nombre réel ou 'q' pour quitter.\n")

